syntax = "proto3";
option cc_enable_arenas = true;

package tensorflow.tensorforest;

import "tensorflow/contrib/decision_trees/proto/generic_tree_model.proto";


// Collection of FertileSlot records, one per tracked node.
message FertileStats {
  // Tracks stats for each node.  node_to_slot[i] is the FertileSlot for node i.
  // This may be sized to max_nodes initially, or grow dynamically as needed.
  // NOTE(review): FertileSlot also carries its own node_id; confirm whether
  // readers index this list by node id or look slots up via node_id.
  repeated FertileSlot node_to_slot = 1;
}


// Running statistic used to compute (weighted) Gini impurity incrementally
// for classification.
message GiniStats {
  // This allows us to quickly track and calculate impurity (classification)
  // from the sum of input weights and the sum of the squares of the per-label
  // weight totals.  Only the sum of squares is stored in this message; the
  // sum of weights is not a field here (presumably it is the enclosing
  // LeafStat.weight_sum -- confirm).  Weighted gini is then:
  //   1 - square / (sum * sum)
  // Updates to these numbers are:
  //   old_i = leaf->value(label)
  //   new_i = old_i + incoming_weight
  //   sum -> sum + incoming_weight
  //   square -> square - (old_i ^ 2) + (new_i ^ 2)
  //   total_left_sum -> total_left_sum - old_left_i * old_total_i +
  //                                      new_left_i * new_total_i
  // (total_left_sum is likewise not a field of this message.)

  // Field number 1 is not assigned in this message.  Reserve it so it can
  // never be reused with a meaning that conflicts with any data previously
  // written under it.
  reserved 1;

  // Sum over labels of the squared per-label weight totals, maintained with
  // the "square" update rule above.
  float square = 2;
}

// Statistics accumulated for the examples seen at a leaf, or at one side of a
// candidate split (this type is used by FertileSlot and SplitCandidate).
message LeafStat {
  // The sum of the weights of the training examples that we have seen.
  // This is here, outside of the leaf_stat oneof, because almost all
  // types will want it.
  float weight_sum = 3;

  // TODO(thomaswc): Move the GiniStats out of LeafStats and into something
  // that only tracks them for splits.
  //
  // Classification statistics: per-label weight counts, stored either densely
  // or sparsely (at most one of the two is set), plus the running sum of
  // squares used to compute Gini impurity.
  // NOTE(review): presumably the count vectors are indexed by class label --
  // confirm against how decision_trees.Vector / SparseVector are filled.
  message GiniImpurityClassificationStats {
    oneof counts {
      decision_trees.Vector dense_counts = 1;
      decision_trees.SparseVector sparse_counts = 2;
    }
    GiniStats gini = 3;
  }

  // This is the info needed for calculating variance for regression.
  // Variance will still have to be summed over every output, but the
  // number of outputs in regression problems is almost always 1.
  // NOTE(review): assuming these hold per-output running means of y and y^2,
  // the per-output variance is mean_output_squares - mean_output^2 -- confirm.
  message LeastSquaresRegressionStats {
    decision_trees.Vector mean_output = 1;
    decision_trees.Vector mean_output_squares = 2;
  }

  // Task-specific statistics; as a oneof, at most one of these is set.
  oneof leaf_stat {
    GiniImpurityClassificationStats classification = 1;
    LeastSquaresRegressionStats regression = 2;
    // TODO(thomaswc): Add in v5's SparseClassStats.
  }
}

// Training state kept for one node: the statistics of the examples seen at
// it and the candidate splits being evaluated for it.
message FertileSlot {
  // Field numbers 2 and 3 are not assigned in this message.  Reserve them so
  // they cannot later be reused with a different meaning (they may have
  // belonged to removed fields -- check history before un-reserving).
  reserved 2, 3;

  // The statistics for *all* the examples seen at this leaf.
  LeafStat leaf_stats = 4;

  // The candidate splits under consideration for this leaf.
  repeated SplitCandidate candidates = 1;

  // The statistics for the examples seen at this leaf after all the
  // splits have been initialized.  If post_init_leaf_stats.weight_sum
  // is > 0, then all candidates have been initialized.  We need to track
  // both leaf_stats and post_init_leaf_stats because the first is used
  // to create the decision_tree::Leaf and the second is used to infer
  // the statistics for the right side of a split (given the left side
  // stats).
  LeafStat post_init_leaf_stats = 6;

  // Identifier of the node this slot tracks.
  int32 node_id = 5;

  // Depth of the node in the tree.
  // NOTE(review): confirm whether the root is at depth 0 or 1.
  int32 depth = 7;
}

// One potential split for a fertile leaf, together with the statistics
// collected to evaluate it.
message SplitCandidate {
  // Field numbers 2 and 3 are not assigned in this message.  Reserve them so
  // they cannot later be reused with a different meaning (they may have
  // belonged to removed fields -- check history before un-reserving).
  reserved 2, 3;

  // proto representing the potential node.
  decision_trees.BinaryNode split = 1;

  // Statistics for the examples sent to the left side of `split`.
  // Right counts are inferred from FertileSlot.leaf_stats and left.
  // NOTE(review): FertileSlot documents post_init_leaf_stats (not
  // leaf_stats) as the source for inferring right-side stats; confirm which
  // one the implementation uses.
  LeafStat left_stats = 4;

  // Right stats (not full counts) are kept here.
  // NOTE(review): presumably only the parts that cannot be inferred from the
  // parent and left_stats are populated -- confirm.
  LeafStat right_stats = 5;

  // Field used when training with a graph runner.
  string unique_id = 6;
}

// Proto used for tracking tree paths during inference time.
message TreePath {
  // Nodes are listed in the order they were traversed, i.e. nodes_visited[0]
  // is the tree's root node.
  repeated decision_trees.TreeNode nodes_visited = 1;
}
